{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Trident with Your Own Foundation Model \n",
    "\n",
    "As more and more groups design their own foundation model, we want to offer easy tools for custom integration. This is the idea of the `CustomInferenceEncoder` from the `patch_encoder_models` module. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "import torch\n",
    "import timm\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "from trident.patch_encoder_models import CustomInferenceEncoder\n",
    "\n",
    "# Load your custom model (eg ViT pretrained on ImageNet)\n",
    "model = timm.create_model('eva02_large_patch14_448.mim_m38m_ft_in22k_in1k', pretrained=True)\n",
    "model = model.eval()\n",
    "model.head = torch.nn.Identity()  \n",
    "\n",
    "# Set precision\n",
    "precision = torch.float16\n",
    "\n",
    "# Set transforms\n",
    "data_config = timm.data.resolve_model_data_config(model)\n",
    "eval_transforms = timm.data.create_transform(**data_config, is_training=False)\n",
    "\n",
    "# Create custom encoder\n",
    "custom_patch_encoder = CustomInferenceEncoder(\n",
    "    enc_name='my_custom_model',\n",
    "    model=model,\n",
    "    transforms=eval_transforms,\n",
    "    precision=precision\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integrate the above model into Trident \"regular\" pipeline, e.g., using the Processor\n",
    "import os\n",
    "import torch\n",
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "from trident.Processor import Processor\n",
    "from trident.segmentation_models import segmentation_model_factory\n",
    "\n",
    "OUTPUT_DIR = \"tutorial-2/\"\n",
    "DEVICE = f\"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "WSI_FNAME = '394140.svs'\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "local_wsi_dir = snapshot_download(\n",
    "    repo_id=\"MahmoodLab/unit-testing\",\n",
    "    repo_type='dataset',\n",
    "    local_dir=os.path.join(OUTPUT_DIR, 'wsis'),\n",
    "    allow_patterns=[WSI_FNAME]\n",
    ")\n",
    "\n",
    "# Create processor\n",
    "processor = Processor(\n",
    "    job_dir=OUTPUT_DIR,       # Directory to store outputs\n",
    "    wsi_source=local_wsi_dir, # Directory containing WSI files\n",
    ")\n",
    "\n",
    "# Run tissue vs background segmentation\n",
    "segmentation_model = segmentation_model_factory('hest')\n",
    "processor.run_segmentation_job(\n",
    "    segmentation_model,\n",
    "    device=DEVICE\n",
    ")\n",
    "\n",
    "# Run tissue coordinate extraction (256x256 at 20x)\n",
    "processor.run_patching_job(\n",
    "    target_magnification=20,\n",
    "    patch_size=256,\n",
    "    overlap=0\n",
    ")\n",
    "\n",
    "# Run patch feature extraction using the custom encoder\n",
    "processor.run_patch_feature_extraction_job(\n",
    "    coords_dir=f'20x_256px_0px_overlap', # Make sure to change this if you changed the patching parameters\n",
    "    patch_encoder=custom_patch_encoder,\n",
    "    device=DEVICE,\n",
    "    saveas='h5',\n",
    "    batch_limit=32\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "trident",
   "language": "python",
   "name": "trident"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
